# -*- coding: utf-8 -*-
import json
import os
import string
import zipfile
import traceback
from pathlib import Path
import requests
from obs import ObsClient
from urllib.parse import unquote_plus
from requests.packages.urllib3.exceptions import InsecureRequestWarning
from ges_graph import parse_csv_to_graph, clear

# NOTE(review): GraphHandler.import_ges calls GES with verify=False, so the
# InsecureRequestWarning that urllib3 would emit is silenced process-wide here.
requests.packages.urllib3.disable_warnings(InsecureRequestWarning)

# Local scratch directory used for downloaded archives and extracted files;
# handler() deletes the files and sub-directories beneath it after each run.
LOCAL_MOUNT_PATH = '/tmp/'
# NOTE(review): appears unused in this module — confirm before removing.
LETTERS = string.ascii_letters


def handler(event, context):
    """FunctionGraph entry point: turn one OBS upload event into a GES import.

    Validates the agency credentials and the required user data, then hands
    the first record of ``event['Records']`` to ``GraphHandler.run``. The OBS
    client is always closed and the local scratch directory always cleaned
    afterwards, whether the run succeeded or not.

    Returns:
        The status string from ``GraphHandler.run`` ('SUCCESS' / 'FAILED'),
        or a human-readable error message when validation fails or the run
        raises.
    """
    log = context.getLogger()
    ak = context.getAccessKey()
    sk = context.getSecretKey()
    # `not` also covers None, not just the empty string.
    if not ak or not sk:
        log.error('ak or sk is empty. Please set an agency.')
        return 'ak or sk is empty. Please set an agency.'

    if not context.getUserData('obs_endpoint'):
        log.error('obs_endpoint is not configured')
        return 'obs_endpoint is not configured'

    if not context.getUserData('target_bucket'):
        log.error('target_bucket is not configured')
        return 'target_bucket is not configured'

    graph_handler = GraphHandler(context)
    try:
        return graph_handler.run(event['Records'][0])
    except Exception:
        exec_info = traceback.format_exc()
        log.error(f"failed to run extract handler: {exec_info}")
        return f"failed to run extract handler: {exec_info}"
    finally:
        # A failing close() must not prevent the local cleanup below.
        try:
            graph_handler.obs_client.close()
        except Exception:
            log.error(f"failed to close obs client: {traceback.format_exc()}")
        log.info(f'start to clean local files in {graph_handler.download_dir}')
        graph_handler.clean_local_files(graph_handler.download_dir)
        log.info(
            f'succeeded to clean local files in {graph_handler.download_dir}')


class GraphHandler:
    """Download a zipped CSV dataset from OBS, convert it and import it into GES.

    One instance handles one OBS event record (see ``run``). The archive is
    downloaded into ``download_dir``, extracted into ``extract_dir``, and
    converted by ``ges_graph.parse_csv_to_graph``. The resulting files are
    uploaded to ``target_bucket`` under ``<target_prefix>/<task_id>/``, and a
    GES ``import-graph`` action is then requested for them.
    """

    def __init__(self, context):
        self.logger = context.getLogger()
        obs_endpoint = context.getUserData("obs_endpoint")
        self.obs_client = new_obs_client(context, obs_endpoint)
        self.target_bucket = context.getUserData("target_bucket")
        self.target_prefix = context.getUserData("target_prefix", "")
        self.download_dir = gen_local_download_path()
        # Sent to GES as obsParameters in import_ges.
        self.accessKey = context.getAccessKey()
        self.secretKey = context.getSecretKey()
        # Used as the X-Auth-Token header of the GES REST call.
        self.token = context.getToken()
        self.region = context.getUserData("region")
        self.graph_ip = context.getUserData("graph_ip")
        self.graph_port = context.getUserData("graph_port")
        self.project_id = context.getProjectID()
        self.graph_name = context.getUserData("graph_name")
        # Set by run() once the archive name is known.
        self.extract_dir = ""

    def run(self, record):
        """Process one OBS event record end to end.

        The task id is the part of the object's file name before the first
        '-'; it namespaces the uploaded graph files in OBS.

        Returns:
            'SUCCESS' when every step succeeded, 'FAILED' when a step
            reported failure. Unexpected errors propagate to the caller.
        """
        bucket, object_key = get_obs_obj_info(record)
        object_key = unquote_plus(object_key)
        self.logger.info("src bucket:" + bucket)
        self.logger.info("src object_key:" + object_key)
        file_name = os.path.basename(object_key)
        if not file_name:
            # e.g. a "folder/" marker object: nothing to download.
            self.logger.error(f"object key {object_key} has no file name")
            return "FAILED"
        task_id = file_name.split("-")[0]

        download_path = os.path.join(self.download_dir, file_name)
        if not self.download_file_from_obs(bucket, object_key, download_path):
            return "FAILED"

        self.logger.info('start to extract file %s', download_path)
        self.extract_dir = os.path.join(self.download_dir,
                                        os.path.splitext(file_name)[0])
        if not self.extract_download_file(download_path):
            return "FAILED"
        self.logger.info('succeeded to extract file %s', download_path)

        graph_dir = parse_csv_to_graph(self.extract_dir)
        self.logger.info('succeeded to parse csv to graph')

        status = self._upload_and_import(graph_dir, task_id)
        # NOTE(review): clear() comes from ges_graph; it is called here in the
        # same state as before (after parsing) — confirm whether it should
        # also run when parse/upload raise.
        clear()
        return status

    def _upload_and_import(self, graph_dir, task_id):
        """Upload the converted graph files, then trigger the GES import.

        The import is skipped when any upload failed, so GES never reads a
        partial graph.
        """
        if not self.upload_local_file(graph_dir, task_id):
            self.logger.error(
                'failed to upload some files in %s to obs, skip importing to ges',
                graph_dir)
            return "FAILED"
        self.logger.info('succeeded to upload files in %s to obs', graph_dir)

        if not self.import_ges(task_id):
            self.logger.error('failed to import graph files to ges')
            return "FAILED"
        self.logger.info('succeeded to import graph files to ges')
        return 'SUCCESS'

    def download_file_from_obs(self, bucket, obj_name, download_path):
        """Download ``bucket/obj_name`` to the local ``download_path``.

        Returns:
            True on success, False if OBS answered with status >= 300 or the
            SDK call raised (both cases are logged).
        """
        self.logger.info('start to download object %s from obs %s to local %s',
                         obj_name, bucket, download_path)
        try:
            resp = self.obs_client.getObject(bucket, obj_name,
                                             downloadPath=download_path)
        except Exception:
            self.logger.error(
                f"failed to download file {obj_name} from obs bucket {bucket}, "
                f"exp:{traceback.format_exc()}")
            return False

        if resp.status < 300:
            self.logger.info(
                'succeeded to download object %s from obs %s to local %s',
                obj_name, bucket, download_path)
            return True
        self.logger.error(
            f"failed to download object {obj_name} from obs {bucket}, "
            f"errorCode:{resp.errorCode} errorMessage:{resp.errorMessage}")
        return False

    def extract_download_file(self, download_path):
        """Extract the zip archive at ``download_path`` into ``extract_dir``.

        The zipfile module strips leading slashes and '..' components from
        member names, so entries stay inside ``extract_dir``.

        Returns:
            True on success, False if the file is not a valid zip archive.
        """
        try:
            with zipfile.ZipFile(download_path, 'r') as archive:
                archive.extractall(self.extract_dir)
        except zipfile.BadZipFile:
            self.logger.error(
                f"failed to extract {download_path}: not a valid zip file")
            return False
        return True

    def upload_local_file(self, directory, task_id):
        """Recursively upload every regular file under ``directory``.

        Keeps uploading after a failure so the log lists every failing file.

        Returns:
            True if every upload succeeded, False otherwise.
        """
        all_uploaded = True
        for filename in os.listdir(directory):
            file_path = os.path.join(directory, filename)
            if os.path.isfile(file_path):
                object_key = self.get_object_key(file_path, task_id)
                uploaded = self.upload_file_to_obs(object_key, file_path)
            elif os.path.isdir(file_path):
                uploaded = self.upload_local_file(file_path, task_id)
            else:
                continue
            all_uploaded = all_uploaded and uploaded
        return all_uploaded

    def _obs_key_prefix(self, task_id):
        """Return ``<target_prefix>/<task_id>``, or just ``task_id`` when no
        prefix is configured, without leading or trailing slashes.

        Shared by upload keys and GES import paths so that the two always
        agree.
        """
        prefix = (self.target_prefix or "").strip("/")
        return f"{prefix}/{task_id}" if prefix else task_id

    def get_object_key(self, file_path, task_id):
        """Map a local file to its OBS key.

        The key is ``<target_prefix>/<task_id>/<path relative to extract_dir>``.
        Paths are normalised first, so redundant separators such as
        '/tmp//x' still match ``extract_dir``. A file outside
        ``extract_dir`` keeps its full path, minus the leading slash.
        """
        rel_path = os.path.normpath(file_path)
        if self.extract_dir:
            base = os.path.normpath(self.extract_dir).rstrip(os.sep)
            if rel_path == base:
                rel_path = ""
            elif rel_path.startswith(base + os.sep):
                rel_path = rel_path[len(base):]
        rel_path = rel_path.replace(os.sep, "/").lstrip("/")
        return self._obs_key_prefix(task_id) + "/" + rel_path

    def upload_file_to_obs(self, object_key, file_path):
        """Upload ``file_path`` to ``target_bucket`` as ``object_key``.

        Returns:
            True on success, False if OBS answered with status >= 300 or the
            SDK call raised (both cases are logged).
        """
        try:
            resp = self.obs_client.putFile(self.target_bucket, object_key,
                                           file_path)
        except Exception:
            self.logger.error(f"failed to upload file {file_path} to "
                              f"obs bucket {self.target_bucket}, "
                              f"exp:{traceback.format_exc()}")
            return False
        # Same 2xx success threshold as download_file_from_obs (status 300 is
        # not a success).
        if resp.status >= 300:
            self.logger.error(f"failed to upload file {file_path} to "
                              f"obs bucket {self.target_bucket}, "
                              f"errorCode:{resp.errorCode} "
                              f"errorMessage:{resp.errorMessage}")
            return False
        return True

    def clean_local_files(self, file_path):
        """Delete ``file_path`` if it is a file or symlink. If it is a
        directory, delete everything beneath it but keep the directory.

        Symbolic links are removed rather than followed, so cleaning never
        reaches files outside ``file_path``. Deletion failures are logged
        rather than raised, because this runs from ``handler``'s ``finally``.
        """
        if os.path.islink(file_path) or os.path.isfile(file_path):
            self.delete_local_file(file_path)
            return
        if not os.path.isdir(file_path):
            return
        for filename in os.listdir(file_path):
            sub_file_path = os.path.join(file_path, filename)
            if os.path.islink(sub_file_path) or os.path.isfile(sub_file_path):
                self.delete_local_file(sub_file_path)
            elif os.path.isdir(sub_file_path):
                self.clean_local_files(sub_file_path)
                try:
                    os.rmdir(sub_file_path)
                except OSError:
                    self.logger.error(
                        f"failed to delete local dir {sub_file_path}, "
                        f"exp:{traceback.format_exc()}")

    def delete_local_file(self, file_path):
        """Remove a single file or symlink, logging (not raising) on failure."""
        try:
            os.remove(file_path)
        except OSError:
            self.logger.error(
                f"failed to delete local file {file_path}, "
                f"exp:{traceback.format_exc()}")

    def import_ges(self, task_id):
        """Ask GES to import the graph files uploaded for ``task_id``.

        Posts an ``import-graph`` action that points GES at the edge and
        vertex directories and the schema file under
        ``<target_bucket>/<target_prefix>/<task_id>/graph``. The function's
        AK/SK are passed so GES can read those objects from OBS.

        Returns:
            True if GES answered with a status below 300, False otherwise.
        """
        url = f"https://{self.graph_ip}:{self.graph_port}/ges/v1.0/" \
              f"{self.project_id}/graphs/{self.graph_name}/action?" \
              f"action_id=import-graph"
        graph_root = (f"{self.target_bucket}/"
                      f"{self._obs_key_prefix(task_id)}/graph")
        payload = {
            "edgesetPath": graph_root + "/edge",
            "edgesetFormat": "csv",
            "vertexsetPath": graph_root + "/vertex",
            "vertexsetFormat": "csv",
            "schemaPath": graph_root + "/schema.xml",
            "obsParameters": {
                "accessKey": self.accessKey,
                "secretKey": self.secretKey,
                "region": self.region
            }
        }
        headers = {
            'Content-Type': 'application/json',
            'X-Auth-Token': self.token
        }
        # NOTE(review): verify=False disables TLS certificate checks on a
        # request that carries the AK/SK. Consider trusting the GES CA instead.
        response = requests.request("POST", url, headers=headers,
                                    data=json.dumps(payload), verify=False)
        self.logger.info("import graph files to ges result: %s",
                         response.text.encode('utf8'))
        if response.status_code >= 300:
            self.logger.error(
                f"ges import-graph request failed with status "
                f"{response.status_code}")
            return False
        return True


def new_obs_client(context, obs_server):
    """Create an ObsClient for ``obs_server`` using the agency AK/SK from ``context``."""
    credentials = {
        "access_key_id": context.getAccessKey(),
        "secret_access_key": context.getSecretKey(),
    }
    return ObsClient(server=obs_server, **credentials)


def get_obs_obj_info(record):
    """Return ``(bucket_name, object_key)`` from an OBS event record.

    Records carrying an 's3' section are read from it; otherwise the 'obs'
    section is used. A KeyError is raised when the required keys are absent.
    """
    section = record['s3'] if 's3' in record else record['obs']
    bucket_name = section['bucket']['name']
    object_key = section['object']['key']
    return bucket_name, object_key


def gen_local_download_path():
    """Return the local scratch directory for downloads and extracted files.

    NOTE(review): this is the shared mount root itself, not a per-invocation
    subdirectory, and handler() deletes the files and sub-directories beneath
    the returned path after each run. Confirm nothing else in the instance
    keeps files there.
    """
    return LOCAL_MOUNT_PATH
